

import inspect
import sys

import torch

# Top-level keys of the loaded checkpoint dict:
#   args
#   checkpoint_version
#   iteration
#   model

# Checkpoint inspected when no path is given on the command line.
DEFAULT_CHECKPOINT = (
    "/share/project/ayl/FlagScale/tmp/iter_0030000/mp_rank_00/"
    "model_optim_rng.pt"
)


def load_checkpoint(path=DEFAULT_CHECKPOINT):
    """Load a checkpoint file onto the CPU and return the unpickled object.

    Args:
        path: Filesystem path of the checkpoint (``.pt``) file.

    Returns:
        Whatever object was saved, typically a dict with keys such as
        ``args``, ``iteration`` and ``model``.

    WARNING: this performs a full pickle load, which can execute arbitrary
    code. Only use it on checkpoints from a trusted source.
    """
    kwargs = {"map_location": "cpu"}
    # torch >= 2.6 defaults to weights_only=True, which rejects non-tensor
    # objects (e.g. an argparse.Namespace stored under "args"). Opt out
    # explicitly when this torch version knows the parameter at all.
    if "weights_only" in inspect.signature(torch.load).parameters:
        kwargs["weights_only"] = False
    return torch.load(path, **kwargs)


def encoder_keys(sd):
    """Return the keys of ``sd["model"]["language_model"]["encoder"]``."""
    return sd["model"]["language_model"]["encoder"].keys()


if __name__ == "__main__":
    # Optional first CLI argument overrides the default checkpoint path.
    ckpt_path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_CHECKPOINT
    sd = load_checkpoint(ckpt_path)
    print(encoder_keys(sd))

